import argparse
import logging
import os
import concurrent.futures

import numpy as np
import pandas as pd
import torch
from openpyxl import Workbook

from common import create_folders, make_env
from hyperparameters import get_hyperparameters
from sac import SACGRU, SAC

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Constants
LOG_DIR = "logs"
# Open-loop horizons evaluated for every checkpoint: 1, then the even numbers 2..30.
STEP_RANGES = [1, *range(2, 32, 2)]

# The FileHandler below opens a file inside LOG_DIR, so the directory must exist first.
os.makedirs(LOG_DIR, exist_ok=True)

# Log both to logs/evaluation.log and to the console (stderr by default).
_log_handlers = [
    logging.FileHandler(os.path.join(LOG_DIR, "evaluation.log")),
    logging.StreamHandler(),
]
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=_log_handlers,
)


def create_excel_if_not_exists(file_path):
    """Create an Excel workbook at *file_path* unless something already exists there.

    A new workbook gets a single active sheet titled "Sheet". An existing file
    is left untouched. Either outcome is logged at INFO level.
    """
    if os.path.exists(file_path):
        logging.info(f"Workbook already exists at {file_path}")
        return

    book = Workbook()
    first_sheet = book.active
    first_sheet.title = "Sheet"
    book.save(filename=file_path)
    logging.info(f"New workbook created and saved as {file_path}")


def setup_environment(env_name, seed):
    """Return ``make_env(env_name, seed)``; no additional setup is performed here."""
    environment = make_env(env_name, seed)
    return environment


def evaluate_policy(policy, eval_env, steps_list, action_dim, episodes=10):
    """Evaluate a sequence-producing policy open-loop for several horizons.

    For each ``steps`` in ``steps_list``, the policy is queried once every
    ``steps`` environment steps. ``policy.policy.sample`` returns a block of
    actions that is executed without re-querying. Execution stops early when
    the episode ends.

    Args:
        policy: object whose ``policy.sample(state, prev_action, steps, True)``
            returns a 3-tuple; the third element is a tensor of actions.
        eval_env: environment whose ``reset()`` returns the observation and
            whose ``step(action)`` returns ``(obs, reward, done, info)``.
        steps_list: iterable of open-loop horizons to evaluate.
        action_dim: length of the action vector. Used for the all-zeros
            "previous action" at the start of each episode.
        episodes: number of episodes averaged per horizon. Defaults to 10,
            the value that was previously hard-coded.

    Returns:
        list of ``(steps, avg_reward)`` tuples, in ``steps_list`` order.

    Raises:
        ValueError: if ``episodes`` is not positive (the average would be undefined).
    """
    if episodes <= 0:
        raise ValueError(f"episodes must be positive, got {episodes}")

    results = []
    for steps in steps_list:
        total_rewards = 0
        logging.info(f"Evaluating with steps: {steps}")
        for _ in range(episodes):
            eval_state, eval_done = eval_env.reset(), False
            eval_prev_action = torch.zeros(action_dim)
            while not eval_done:
                # as_tensor accepts both numpy arrays and tensors (the first
                # prev-action is a torch tensor, later ones are numpy arrays).
                state_tensor = torch.as_tensor(eval_state, dtype=torch.float32, device=device).reshape(1, -1)
                prev_action_tensor = torch.as_tensor(
                    eval_prev_action, dtype=torch.float32, device=device
                ).reshape(1, -1)
                # Pure inference: skip building an autograd graph.
                # The trailing True is presumably an evaluation/deterministic flag (TODO confirm in sac.py).
                with torch.no_grad():
                    _, _, eval_actions = policy.policy.sample(state_tensor, prev_action_tensor, steps, True)
                eval_actions = eval_actions.detach().cpu().numpy()[0]

                for step in range(steps):
                    # With a single step the whole first-batch slice is used as the action.
                    eval_action = eval_actions[step] if steps > 1 else eval_actions
                    eval_state, eval_reward, eval_done, _ = eval_env.step(eval_action)
                    eval_prev_action = eval_action
                    total_rewards += eval_reward
                    if eval_done:
                        break
        avg_reward = total_rewards / episodes
        logging.info(f"Evaluation reward: {avg_reward:.3f}")
        results.append((steps, avg_reward))
    return results


def get_env_and_policy(seed, env_name, policy_type, train_steps=None):
    """Build an environment and load a trained SAC or SAC-GRU checkpoint.

    Args:
        seed: training seed of the checkpoint. The environment is created
            with ``seed + 100``.
        env_name: environment id, also used to look up SAC hyperparameters.
        policy_type: ``'GRU'`` loads a ``SACGRU`` policy; any other value
            loads a plain ``SAC`` policy.
        train_steps: horizon the GRU policy was trained with. It is required
            for ``'GRU'`` and ignored otherwise.

    Returns:
        ``(env, policy)``, with the weights loaded from
        ``./models/<checkpoint name>``.

    Raises:
        ValueError: if ``policy_type == 'GRU'`` and ``train_steps`` is None.
            Otherwise a nonexistent ``..._None_4_best`` checkpoint would be
            requested.
    """
    if policy_type == 'GRU' and train_steps is None:
        raise ValueError("train_steps is required when policy_type is 'GRU'")

    hyperparams = get_hyperparameters(env_name, 'SAC')

    env = make_env(env_name, seed + 100)
    state_dim = env.observation_space.shape[0]

    policy_kwargs = {
        "gamma": hyperparams['discount'],
        "tau": hyperparams['tau'],
        "alpha": hyperparams['alpha'],
        "policy_type": "Gaussian",
        "hidden_size": hyperparams['hidden_size'],
        "target_update_interval": hyperparams['target_update_interval'],
        "automatic_entropy_tuning": True,
        "lr": hyperparams['lr'],
    }

    if policy_type == 'GRU':
        policy_kwargs["steps"] = train_steps
        policy = SACGRU(state_dim, env.action_space, **policy_kwargs)
        file_name = f"SAC_{policy_type}_{env_name}_{seed}_True_{train_steps}_4_best"
    else:
        policy = SAC(state_dim, env.action_space, **policy_kwargs)
        file_name = f"SAC_{env_name}_{seed}_True_best"

    policy.load_checkpoint(f"./models/{file_name}")

    return env, policy


def evaluate_episode(policy, env_name, steps, algo='gru', seed=100):
    """Run one evaluation episode in a fresh environment and return its total reward.

    Every ``steps`` environment steps the policy is queried via ``select_action``:
      - With ``algo == 'gru'`` it receives the previous action and ``steps``, and
        its output is indexed per step when ``steps > 1``.
      - For any other ``algo`` the single returned action is repeated ``steps`` times.

    Args:
        policy: object exposing ``select_action``, as used below.
        env_name: environment id passed to ``make_env``.
        steps: number of environment steps executed per policy query.
        algo: ``'gru'`` for sequence policies; anything else for plain SAC.
        seed: seed passed to ``make_env``. Defaults to 100, the value that was
            previously hard-coded.

    Returns:
        Sum of the rewards returned by ``env.step`` over the episode.
    """
    eval_env = make_env(env_name, seed)
    is_gru = algo == 'gru'
    try:
        eval_state, eval_done = eval_env.reset(), False
        action_dim = eval_env.action_space.shape[0]
        eval_prev_action = torch.zeros(action_dim)
        episode_reward = 0

        while not eval_done:
            if is_gru:
                eval_actions = policy.select_action(eval_state, eval_prev_action, steps, evaluate=True)
            else:
                eval_actions = policy.select_action(eval_state, evaluate=True)
            for step in range(steps):
                eval_action = eval_actions[step] if steps > 1 and is_gru else eval_actions
                eval_state, eval_reward, eval_done, _ = eval_env.step(eval_action)
                eval_prev_action = eval_action
                episode_reward += eval_reward
                if eval_done:
                    break
    finally:
        # Release the per-episode environment. Guarded in case make_env returns
        # an object without close().
        close = getattr(eval_env, "close", None)
        if callable(close):
            close()

    return episode_reward


def parallel_evaluate_policy(policy, env_name, steps, episodes=10, algo='gru', base_seed=100):
    """Average ``evaluate_episode`` over ``episodes`` episodes run in a thread pool.

    Episode ``i`` uses environment seed ``base_seed + i``. Previously every
    episode used the same seed, so with a deterministic (``evaluate=True``)
    policy all episodes replayed the same trajectory and the average carried
    no extra information.

    NOTE(review): one ``policy`` object is shared by all worker threads; confirm
    that ``select_action`` is safe to call concurrently.

    Args:
        policy, env_name, steps, algo: forwarded to ``evaluate_episode``.
        episodes: number of episodes to average.
        base_seed: seed of the first episode (default 100, the old fixed seed).

    Returns:
        ``np.mean`` of the episode rewards.
    """
    with concurrent.futures.ThreadPoolExecutor() as executor:
        futures = [
            executor.submit(evaluate_episode, policy, env_name, steps, algo, base_seed + episode)
            for episode in range(episodes)
        ]
        # Collect in submission order so the float summation order is reproducible.
        results = [future.result() for future in futures]
    return np.mean(results)


def eval(seeds=(0, 1, 2, 3, 4), env_name='LunarLanderContinuous-v2', steps=(2, 4, 8, 16), output_path='eval.csv'):
    """Evaluate saved SAC-GRU and SAC checkpoints and append the results to a CSV.

    Nothing is trained here. For every seed, this loads:
      - one SAC-GRU checkpoint per entry of ``steps`` (its training horizon), then
      - one plain SAC checkpoint.

    Each checkpoint is evaluated with ``parallel_evaluate_policy`` (10 episodes)
    for every horizon in ``STEP_RANGES``.

    NOTE: the name shadows the builtin ``eval``. It is kept because the
    ``__main__`` block calls it by this name.

    Args:
        seeds: training seeds whose checkpoints are evaluated.
        env_name: environment id.
        steps: SAC-GRU training horizons to load.
        output_path: CSV file to append to. No header row is written.
            Columns: seed, train_steps, steps, avg_reward, algorithm, env.
    """
    logging.info("---------------------------------------")
    # list() keeps the log format identical whether lists or tuples are passed.
    logging.info(f"Env: {env_name}, Seed: {list(seeds)}, Steps: {list(steps)}")
    logging.info("---------------------------------------")
    create_folders()

    df = pd.DataFrame(columns=['seed', 'train_steps', 'steps', 'avg_reward', 'algorithm', 'env'])

    for seed in seeds:
        # GRU checkpoints first (one per training horizon), then plain SAC,
        # which is labelled '' and has no training horizon.
        runs = [('GRU', train_steps) for train_steps in steps] + [('', None)]
        for algo, train_steps in runs:
            # The returned env is unused: evaluation episodes build their own.
            _, policy = get_env_and_policy(seed, env_name, algo, train_steps)
            for step_range in STEP_RANGES:
                avg_reward = parallel_evaluate_policy(policy, env_name, step_range, 10, algo.lower())
                df.loc[len(df)] = [seed, train_steps, step_range, avg_reward, algo, env_name]

    # Append mode without a header, so repeated runs accumulate rows in one file.
    df.to_csv(output_path, mode='a', index=False, header=False)


if __name__ == "__main__":
    cli = argparse.ArgumentParser()
    cli.add_argument("--env_name", default="LunarLanderContinuous-v2", help="Environment name")
    cli.add_argument("--seeds", nargs='+', default=[0, 1, 2, 3, 4], type=int, help="Seeds to evaluate")
    cli.add_argument("--steps", nargs='+', default=[2, 4, 8, 16], type=int, help="Steps to evaluate for SRL")

    # Record the effective configuration in the log before running.
    cli_options = vars(cli.parse_args())
    logging.info('Command-line argument values:')
    for option_name, option_value in cli_options.items():
        logging.info(f'- {option_name} : {option_value}')

    eval(**cli_options)
